{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Notebook to train DeepMEL\n",
    "#### Following codes in this notebook were run using conda environments\n",
    "#### Here are the used packages and their version"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bash\n",
    "#cpu\n",
    "conda create --name DeepMEL_conda_env_cpu python=3.6 tensorflow=1.14.0 keras=2.2.4\n",
    "conda activate DeepMEL_conda_env_cpu\n",
    "conda install numpy=1.16.2 matplotlib=3.1.1 shap=0.29.3 ipykernel=5.1.2\n",
    "\n",
    "#gpu\n",
    "conda create --name DeepMEL_conda_env_gpu python=3.6 tensorflow-gpu=1.14.0 keras-gpu=2.2.4\n",
    "conda activate DeepMEL_conda_env_gpu\n",
    "conda install numpy=1.16.2  matplotlib=3.1.1 shap=0.29.3 ipykernel=5.1.2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Loading necessary packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import optparse\n",
    "from array import *\n",
    "\n",
    "import tensorflow\n",
    "import numpy as np\n",
    "\n",
    "import matplotlib\n",
    "#matplotlib.use('pdf')\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import sklearn\n",
    "from sklearn.utils import class_weight, shuffle\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import average_precision_score, roc_auc_score\n",
    "\n",
    "from keras.models import Sequential\n",
    "from keras.models import model_from_json\n",
    "from keras.layers.core import Dense, Dropout, Activation, Flatten\n",
    "from keras.layers.convolutional import Conv1D, MaxPooling1D\n",
    "from keras.layers import Bidirectional, Concatenate, PReLU \n",
    "from keras.optimizers import RMSprop, Adam\n",
    "from keras.callbacks import ModelCheckpoint, EarlyStopping\n",
    "from keras import regularizers\n",
    "from keras.layers.wrappers import Bidirectional, TimeDistributed\n",
    "from keras.layers.recurrent import LSTM\n",
    "from keras.layers import Layer, average, Input\n",
    "from keras.models import Model\n",
    "from keras.utils import plot_model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Defining necessary functitons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_output(input_layer, hidden_layers):\n",
    "    output = input_layer\n",
    "    for hidden_layer in hidden_layers: \n",
    "        output = hidden_layer(output)\n",
    "    return output\n",
    "\n",
    "def build_model():\n",
    "    forward_input = Input(shape=seq_shape)\n",
    "    reverse_input = Input(shape=seq_shape)\n",
    "\n",
    "    hidden_layers = [\n",
    "        Conv1D(128, kernel_size=20, padding=\"valid\", activation='relu', kernel_initializer='random_uniform'),\n",
    "        MaxPooling1D(pool_size=10, strides=10, padding='valid'),\n",
    "        Dropout(0.2),\n",
    "        TimeDistributed(Dense(128, activation='relu')),\n",
    "        Bidirectional(LSTM(128, dropout=0.1, recurrent_dropout=0.1, return_sequences=True)),\n",
    "        Dropout(0.2),\n",
    "        Flatten(),\n",
    "        Dense(256, activation='relu'),\n",
    "        Dropout(0.4),\n",
    "        Dense(len(selected_classes), activation='sigmoid')]\n",
    "    forward_output = get_output(forward_input, hidden_layers)     \n",
    "    reverse_output = get_output(reverse_input, hidden_layers)\n",
    "    output = average([forward_output, reverse_output])\n",
    "    model = Model(input=[forward_input, reverse_input], output=output)\n",
    "\n",
    "    model.summary()\n",
    "    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])\n",
    "    return model\n",
    "\n",
    "def readfile(filename):\n",
    "    ids = []\n",
    "    ids_d = {}\n",
    "    seqs = {}\n",
    "    classes = {}\n",
    "    f = open(filename, 'r')\n",
    "    lines = f.readlines()\n",
    "    f.close()\n",
    "    seq = []\n",
    "    for line in lines:\n",
    "        if line[0] == '>':\n",
    "            ids.append(line[1:].rstrip('\\n'))\n",
    "            if line[1:].rstrip('\\n').split('_')[0] not in seqs:\n",
    "                seqs[line[1:].rstrip('\\n').split('_')[0]] = []\n",
    "            if line[1:].rstrip('\\n').split('_')[0] not in ids_d:\n",
    "                ids_d[line[1:].rstrip('\\n').split('_')[0]] = line[1:].rstrip('\\n').split('_')[0]\n",
    "            if line[1:].rstrip('\\n').split('_')[0] not in classes:\n",
    "                classes[line[1:].rstrip('\\n').split('_')[0]] = np.zeros(NUM_CLASSES)\n",
    "            classes[line[1:].rstrip('\\n').split('_')[0]][int(line[1:].rstrip('\\n').split('_')[1])-1] = 1        \n",
    "            if seq != []: seqs[ids[-2].split('_')[0]]= (\"\".join(seq))\n",
    "            seq = []\n",
    "        else:\n",
    "            seq.append(line.rstrip('\\n').upper())\n",
    "    if seq != []:\n",
    "        seqs[ids[-1].split('_')[0]]=(\"\".join(seq))\n",
    "    return ids,ids_d,seqs,classes\n",
    "\n",
    "def one_hot_encode_along_row_axis(sequence):\n",
    "    to_return = np.zeros((1,len(sequence),4), dtype=np.int8)\n",
    "    seq_to_one_hot_fill_in_array(zeros_array=to_return[0],\n",
    "                                 sequence=sequence, one_hot_axis=1)\n",
    "    return to_return\n",
    "\n",
    "def seq_to_one_hot_fill_in_array(zeros_array, sequence, one_hot_axis):\n",
    "    assert one_hot_axis==0 or one_hot_axis==1\n",
    "    if (one_hot_axis==0):\n",
    "        assert zeros_array.shape[1] == len(sequence)\n",
    "    elif (one_hot_axis==1): \n",
    "        assert zeros_array.shape[0] == len(sequence)\n",
    "    for (i,char) in enumerate(sequence):\n",
    "        if (char==\"A\" or char==\"a\"):\n",
    "            char_idx = 0\n",
    "        elif (char==\"C\" or char==\"c\"):\n",
    "            char_idx = 1\n",
    "        elif (char==\"G\" or char==\"g\"):\n",
    "            char_idx = 2\n",
    "        elif (char==\"T\" or char==\"t\"):\n",
    "            char_idx = 3\n",
    "        elif (char==\"N\" or char==\"n\"):\n",
    "            continue\n",
    "        else:\n",
    "            raise RuntimeError(\"Unsupported character: \"+str(char))\n",
    "        if (one_hot_axis==0):\n",
    "            zeros_array[char_idx,i] = 1\n",
    "        elif (one_hot_axis==1):\n",
    "            zeros_array[i,char_idx] = 1\n",
    "\n",
    "def create_plots(history):\n",
    "    plt.plot(history.history['acc'])\n",
    "    plt.plot(history.history['val_acc'])\n",
    "    plt.title('model accuracy')\n",
    "    plt.ylabel('accuracy')\n",
    "    plt.xlabel('epoch')\n",
    "    plt.legend(['train', 'test'], loc='upper left')\n",
    "    plt.savefig(foldername + 'accuracy.png')\n",
    "    plt.clf()\n",
    "\n",
    "    plt.plot(history.history['loss'])\n",
    "    plt.plot(history.history['val_loss'])\n",
    "    plt.title('model loss')\n",
    "    plt.ylabel('loss')\n",
    "    plt.xlabel('epoch')\n",
    "    plt.legend(['train', 'test'], loc='upper left')\n",
    "    plt.savefig(foldername + 'loss.png')\n",
    "    plt.clf()\n",
    "    \n",
    "def json_hdf5_to_model(json_filename, hdf5_filename):  \n",
    "    with open(json_filename, 'r') as f:\n",
    "        model = model_from_json(f.read())\n",
    "    model.load_weights(hdf5_filename)\n",
    "    return model\n",
    "\n",
    "def loc_to_model_loss(loc):\n",
    "    return json_hdf5_to_model(loc + 'model.json', loc + 'model_best_loss.hdf5')\n",
    "\n",
    "def shuffle_label(label):\n",
    "    for i in range(len(label.T)):\n",
    "        label.T[i] = shuffle(label.T[i])\n",
    "    return label\n",
    "\n",
    "def calculate_roc_pr(score, label):\n",
    "    output = np.zeros((len(label.T), 2))\n",
    "    for i in range(len(label.T)):\n",
    "        roc_ = roc_auc_score(label.T[i], score.T[i])\n",
    "        pr_ = average_precision_score(label.T[i], score.T[i])\n",
    "        output[i] = [roc_, pr_]\n",
    "    return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Preparing the input data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NUM_CLASSES = 24\n",
    "selected_classes = np.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24])-1\n",
    "SEQ_LEN = 500\n",
    "SEQ_DIM = 4\n",
    "seq_shape = (SEQ_LEN, SEQ_DIM)\n",
    "EPOCH = 2\n",
    "BATCH = 128\n",
    "\n",
    "foldername = '/results/'\n",
    "train_filename = '/summits_to_topics_wochr2_wochr11.fa'\n",
    "valid_filename = '/summits_to_topics_chr11.fa'\n",
    "test_filename = '/summits_to_topics_chr2.fa'\n",
    "\n",
    "PATH_TO_SAVE_ARC = foldername + 'model.json'\n",
    "PATH_TO_SAVE_BEST_LOST_WEIGHTS = foldername + 'model_best_loss.hdf5'\n",
    "PATH_TO_SAVE_BEST_ACC_WEIGHTS = foldername + 'model_best_acc.hdf5'\n",
    "PATH_TO_SAVE_END_WEIGHTS = foldername + 'model_end.hdf5'\n",
    "\n",
    "print(\"Prepare input...\")\n",
    "train_ids, train_ids_d, train_seqs, train_classes = readfile(train_filename)\n",
    "X_train = np.array([one_hot_encode_along_row_axis(train_seqs[id]) for id in train_ids_d]).squeeze(axis=1)\n",
    "y_train = np.array([train_classes[id] for id in train_ids_d])\n",
    "y_train = y_train[:,selected_classes]\n",
    "X_train = X_train[y_train.sum(axis=1)>0]\n",
    "y_train = y_train[y_train.sum(axis=1)>0]\n",
    "X_train_rc = X_train[:,::-1,::-1]\n",
    "train_data = [X_train, X_train_rc]\n",
    "\n",
    "valid_ids, valid_ids_d, valid_seqs, valid_classes = readfile(valid_filename)\n",
    "X_valid = np.array([one_hot_encode_along_row_axis(valid_seqs[id]) for id in valid_ids_d]).squeeze(axis=1)\n",
    "y_valid = np.array([valid_classes[id] for id in valid_ids_d])\n",
    "y_valid = y_valid[:,selected_classes]\n",
    "X_valid = X_valid[y_valid.sum(axis=1)>0]\n",
    "y_valid = y_valid[y_valid.sum(axis=1)>0]\n",
    "X_valid_rc = X_valid[:,::-1,::-1]\n",
    "valid_data = [X_valid, X_valid_rc]\n",
    "\n",
    "test_ids, test_ids_d, test_seqs, test_classes = readfile(test_filename)\n",
    "X_test = np.array([one_hot_encode_along_row_axis(test_seqs[id]) for id in test_ids_d]).squeeze(axis=1)\n",
    "y_test = np.array([test_classes[id] for id in test_ids_d])\n",
    "y_test = y_test[:,selected_classes]\n",
    "X_test = X_test[y_test.sum(axis=1)>0]\n",
    "y_test = y_test[y_test.sum(axis=1)>0]\n",
    "X_test_rc = X_test[:,::-1,::-1]\n",
    "test_data = [X_test, X_test_rc]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Training the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Compile model...\")\n",
    "model = build_model()\n",
    "\n",
    "model_json = model.to_json()\n",
    "with open(PATH_TO_SAVE_ARC, \"w\") as json_file:\n",
    "    json_file.write(model_json)\n",
    "print(\"Model architecture saved to..\", PATH_TO_SAVE_ARC)\n",
    "\n",
    "checkpoint1 = ModelCheckpoint(PATH_TO_SAVE_BEST_LOST_WEIGHTS, monitor='val_loss', verbose=1, save_best_only=True, mode='min')\n",
    "checkpoint2 = ModelCheckpoint(PATH_TO_SAVE_BEST_ACC_WEIGHTS, monitor='val_acc', verbose=1, save_best_only=True, mode='max')\n",
    "checkpoint3 = EarlyStopping(monitor='val_loss', patience=6)\n",
    "callbacks_list = [checkpoint1, checkpoint2, checkpoint3]\n",
    "\n",
    "print(\"Train model...\")\n",
    "history = model.fit( train_data, y_train, nb_epoch=EPOCH, batch_size=BATCH, shuffle=True, validation_data=(test_data, y_test),  verbose=1, callbacks= callbacks_list)\n",
    "create_plots(history)\n",
    "model.save_weights(PATH_TO_SAVE_END_WEIGHTS)\n",
    "print(\"Model weights saved to..\", PATH_TO_SAVE_END_WEIGHTS)\n",
    "plot_model(model, to_file=foldername + 'model.png')\n",
    "\n",
    "score, acc = model.evaluate(X_test, y_test, batch_size=BATCH)\n",
    "print('Test score:', score)\n",
    "print('Validation accuracy:', acc)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
